home *** CD-ROM | disk | FTP | other *** search
// Implementation of a NeuroSolutions Momentum component with per-weight
// adaptive step sizes, plus a postprocessor used for error control
- //
- // Version 2.0, 17/10/96
- //
- // Programmed by Alexandre Bernardino
- // IST, Lisbon, Portugal
-
#include "NSDLL.h"
#include <windows.h>
#include <limits.h>
#include <stdlib.h>
-
-
- /* GLOBAL VARIABLES */
-
- int gEpoch=-1;
- NSFloat bestError;
- int ERR_GROW_MORE_THAN_ALLOWED = 0;
- int ERROR_DECREASED = 0;
-
- /*****************************/
- /* Gradient search procedure */
-
- __declspec(dllexport) void performMomentum(
- DLLData *instance, // Pointer to instance data (may be NULL)
- NSFloat *weights, // Pointer to the vector of weights
- int length, // Length of the weight vector
- NSFloat *gradient, // Pointer to the vector of gradients, one for each weight
- NSFloat *step, // Pointer to the learning rate/s
- BOOL individual, // Indicates whether their is one learning rate for all weights (FALSE),
- // or each weight has its own learning rate
- int stepDivisor,// The number each step size should be divided by
- BOOL decayWeights, // TRUE if weight decay is active
- NSFloat decayRate, // Rate to decay weights if weight decay is active
- NSFloat momentum, // Momentum rate for all weights
- NSFloat *delta // Last weight Update
- )
- {
- register int i;
-
- NSFloat *mydata = (NSFloat *)getUserData(instance);
- NSFloat up = getFloatParameter(instance, 0, 0);
- NSFloat down = getFloatParameter(instance, 1, 0);
- NSFloat cut = getFloatParameter(instance, 2, 0);
- //NSFloat maxStep = getFloatParameter(instance, 0, 2);
- //NSFloat minStep = getFloatParameter(instance, 2, 2);
- NSFloat maxStep=1e15f;
- NSFloat minStep=1e-15f;
- NSFloat *lastGradient = mydata;
- NSFloat *bestWeights = mydata+length;
-
- if( !individual )
- {
- // MessageBox(NULL,"Must use individual steps for this gradient search",
- // "Error",MB_OK);
- return;
- }
-
- if( gEpoch < 0 ) // initialization
- {
- for (i=0; i<length; i++) {
- bestWeights[i] = weights[i];
- lastGradient[i] = 0.0f;
- }
- }
-
- if( ERR_GROW_MORE_THAN_ALLOWED )
- {
- for (i=0; i<length; i++) {
- step[i] *= cut;
- if(step[i]<minStep)
- step[i]=minStep;
- delta[i] = 0.0f;
- lastGradient[i] = 0.0f;
- weights[i] = bestWeights[i];
- }
- }
- else
- {
- if( ERROR_DECREASED )
- {
- for (i=0; i<length; i++)
- bestWeights[i] = weights[i];
- }
- for (i=0; i<length; i++) {
- if(lastGradient[i]*gradient[i]>0)
- {
- step[i]=step[i]*up;
- if(step[i]>maxStep)
- step[i]=maxStep;
- }
- else if(lastGradient[i]*gradient[i]<0)
- {
- step[i]=step[i]*down;
- if(step[i]<minStep)
- step[i]=minStep;
- }
- delta[i] = momentum*delta[i] + (step[i]/(NSFloat)stepDivisor)*gradient[i];
- weights[i] += delta[i];
- lastGradient[i] = gradient[i];
- }
- }
- }
-
- /******************************************/
- /* Management of instance data (OPTIONAL) */
-
- __declspec(dllexport) DLLData *allocMomentum(
- DLLData *oldInstance, // Pointer to the last instance if reallocating
- int length, // Length of the weight vector
- BOOL individual // Indicates whether their is one learning rate for all weights (FALSE),
- // or each weight has its own learning rate
- )
- {
- DLLData *instance = allocDLLInstance(oldInstance);
- NSFloat *mydata = (NSFloat *)calloc(length*2, sizeof(NSFloat));
- setUserData(instance, mydata);
- setParameterName(instance, 0, 0, "Up", FALSE);
- setFloatParameter(instance, 0, 0, 1.1f, FALSE);
- setParameterName(instance, 1, 0, "Down", FALSE);
- setFloatParameter(instance, 1, 0, 0.9f, FALSE);
- setParameterName(instance, 2, 0, "Cut", FALSE);
- setFloatParameter(instance, 2, 0, 0.5f, FALSE);
- //setParameterName(instance, 0, 1, "Max.Step", FALSE);
- //setFloatParameter(instance, 0, 1, 1e15f, FALSE);
- //setParameterName(instance, 1, 1, "Min.Step", FALSE);
- //setFloatParameter(instance, 1, 1, 1e-15f, FALSE);
- return instance;
- }
-
- __declspec(dllexport) void freeMomentum(DLLData *instance)
- {
- free(getUserData(instance));
- freeDLLInstance(instance);
- }
-
-
- /********************************************************/
- /***************************/
- /* Activation of component */
- __declspec(dllexport) BOOL performPrePost(
- DLLData *instance, // Pointer to instance data (may be NULL)
- NSFloat *input, // Pointer to the input data
- NSFloat *output, // Pointer to the output data
- int rows, // Number of rows of data
- int cols, // Number of cols of data
- BOOL preprocessor // Flag to indicate whether this is a preprocessor or postprocessor
- )
- {
- int i, length=rows*cols;
- NSFloat tolerance = getFloatParameter(instance, 0, 0);
-
- if( gEpoch < 0 ) // initialization
- {
- bestError = input[0];
- for (i=0; i<length; i++)
- output[i] = bestError;
- return TRUE;
- }
- ERROR_DECREASED = 0;
- ERR_GROW_MORE_THAN_ALLOWED = 0;
- if( input[0] < bestError )
- {
- bestError = input[0];
- ERROR_DECREASED = 1;
- }
- else if( input[0] > bestError*tolerance )
- {
- ERR_GROW_MORE_THAN_ALLOWED = 1;
- }
- for (i=0; i<length; i++)
- output[i] = bestError;
- return TRUE; // Return whether to inject this sample or to call performPrePost with another sample
- }
-
-
- /******************************************/
- /* Management of instance data (OPTIONAL) */
-
- __declspec(dllexport) DLLData *allocPrePost(
- DLLData *oldInstance, // Pointer to the last instance if reallocating
- int *rows, // Number of rows of output data, can be changed to reflect a diffenent number for the input data
- int *cols, // Number of cols of output data, can be changed to reflect a diffenent number for the input data
- BOOL preprocessor // Flag to indicate whether this is a preprocessor or postprocessor
- )
- {
- DLLData *instance = allocDLLInstance(oldInstance);
- setParameterName(instance, 0, 0, "Tolerance", FALSE);
- setFloatParameter(instance, 0, 0, 1.01f, FALSE);
- return instance;
- }
-
- __declspec(dllexport) void freePrePost(DLLData *instance)
- {
- freeDLLInstance(instance);
- }
-
- __declspec(dllexport) void networkReset(DLLData *instance)
- {
- gEpoch=-1;
- ERROR_DECREASED = 0;
- ERR_GROW_MORE_THAN_ALLOWED = 0;
- }
-
- __declspec(dllexport) void epochEnded(DLLData *instance, int epoch)
- {
- gEpoch=epoch;
- }
-
-